{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a name=\"top\"></a><img src=\"images/chisel_1024.png\" alt=\"Chisel logo\" style=\"width:480px;\" />"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Module 4.4: A FIRRTL Transform Example\n",
    "\n",
    "**Prev: [Common Pass Idioms](4.3_firrtl_common_idioms.ipynb)**<br>\n",
    "\n",
    "This module defines `AnalyzeCircuit`, a FIRRTL transform that walks a `firrtl.ir.Circuit` and records the number of add ops it finds in each module.\n",
    "\n",
    "## Setup\n",
    "\n",
    "Please run the following:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "val path = System.getProperty(\"user.dir\") + \"/source/load-ivy.sc\"\n",
    "interp.load.module(ammonite.ops.Path(java.nio.file.FileSystems.getDefault().getPath(path)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "// The FIRRTL imports (compiler infrastructure, IR classes, and map functions)\n",
    "// are declared inside the AnalyzeCircuit class below.\n",
    "\n",
    "// Scala's mutable collections\n",
    "import scala.collection.mutable"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Counting Adders Per Module\n",
    "\n",
    "As described earlier, a FIRRTL circuit is represented as a tree:\n",
    "  - A FIRRTL `Circuit` contains a sequence of `DefModule`s.\n",
    "  - A `DefModule` contains a sequence of `Port`s, and maybe a `Statement`.\n",
    "  - A `Statement` can contain other `Statement`s, or `Expression`s.\n",
    "  - An `Expression` can contain other `Expression`s.\n",
    "\n",
    "To visit all FIRRTL IR nodes in a circuit, we write functions that recursively walk down this tree. To record statistics, we will pass along a `Ledger` class and use it when we come across an add op:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Ledger {\n",
    "  import firrtl.Utils\n",
    "  private var moduleName: Option[String] = None\n",
    "  private val modules = mutable.Set[String]()\n",
    "  private val moduleAddMap = mutable.Map[String, Int]()\n",
    "  def foundAdd(): Unit = moduleName match {\n",
    "    case None => sys.error(\"Module name not defined in Ledger!\")\n",
    "    case Some(name) => moduleAddMap(name) = moduleAddMap.getOrElse(name, 0) + 1\n",
    "  }\n",
    "  def getModuleName: String = moduleName match {\n",
    "    case None => Utils.error(\"Module name not defined in Ledger!\")\n",
    "    case Some(name) => name\n",
    "  }\n",
    "  def setModuleName(myName: String): Unit = {\n",
    "    modules += myName\n",
    "    moduleName = Some(myName)\n",
    "  }\n",
    "  def serialize: String = {\n",
    "    modules map { myName =>\n",
    "      s\"$myName => ${moduleAddMap.getOrElse(myName, 0)} add ops!\"\n",
    "    } mkString \"\\n\"\n",
    "  }\n",
    "}"
   ]
  },
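  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The recursive walk described above can be sketched on a toy tree, independent of FIRRTL. This is only an illustration under stated assumptions: `ToyExpr`, `ToyStmt`, `ToyModule`, and the `count*` functions are hypothetical stand-ins, not `firrtl.ir` classes, and a plain mutable map plays the role of the `Ledger`.\n",
    "\n",
    "```scala\n",
    "import scala.collection.mutable\n",
    "\n",
    "// Hypothetical toy IR, standing in for the real firrtl.ir nodes.\n",
    "sealed trait ToyExpr\n",
    "case class ToyRef(name: String) extends ToyExpr\n",
    "case class ToyAdd(a: ToyExpr, b: ToyExpr) extends ToyExpr\n",
    "\n",
    "sealed trait ToyStmt\n",
    "case class ToyConnect(loc: String, value: ToyExpr) extends ToyStmt\n",
    "case class ToyBlock(stmts: Seq[ToyStmt]) extends ToyStmt\n",
    "\n",
    "case class ToyModule(name: String, body: ToyStmt)\n",
    "\n",
    "// Per-module add counts, playing the role of the Ledger.\n",
    "val addCounts = mutable.LinkedHashMap[String, Int]()\n",
    "\n",
    "// Visit an expression and all of its sub-expressions.\n",
    "def countExpr(module: String)(e: ToyExpr): Unit = e match {\n",
    "  case ToyAdd(a, b) =>\n",
    "    countExpr(module)(a)\n",
    "    countExpr(module)(b)\n",
    "    addCounts(module) = addCounts.getOrElse(module, 0) + 1\n",
    "  case ToyRef(_) => ()\n",
    "}\n",
    "\n",
    "// Visit a statement, its nested statements, and their expressions.\n",
    "def countStmt(module: String)(s: ToyStmt): Unit = s match {\n",
    "  case ToyConnect(_, value) => countExpr(module)(value)\n",
    "  case ToyBlock(stmts)      => stmts.foreach(st => countStmt(module)(st))\n",
    "}\n",
    "\n",
    "// Register the module, then walk its body.\n",
    "def countModule(m: ToyModule): Unit = {\n",
    "  addCounts.getOrElseUpdate(m.name, 0)\n",
    "  countStmt(m.name)(m.body)\n",
    "}\n",
    "\n",
    "val top = ToyModule(\"Top\", ToyBlock(Seq(\n",
    "  ToyConnect(\"x\", ToyAdd(ToyRef(\"a\"), ToyAdd(ToyRef(\"b\"), ToyRef(\"c\")))),\n",
    "  ToyConnect(\"y\", ToyRef(\"a\"))\n",
    ")))\n",
    "countModule(top)\n",
    "println(addCounts)\n",
    "```\n",
    "\n",
    "With the tree above, `addCounts(\"Top\")` ends up as 2, because the expression connected to `x` nests two adds and the one connected to `y` has none."
   ]
  },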
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's define a FIRRTL Transform that walks the circuit and updates our `Ledger` whenever it comes across an adder (`DoPrim` with op argument `Add`). Don't worry about `inputForm` or `outputForm` for now.\n",
    "\n",
    "Take some time to understand how `walkModule`, `walkStatement`, and `walkExpression` enable traversing all `DefModule`, `Statement`, and `Expression` nodes in the FIRRTL AST.\n",
    "\n",
    "Questions to answer:\n",
    "  - **Why doesn't walkModule call walkExpression?**\n",
    "  - **Why does walkExpression do a post-order traversal?**\n",
    "  - **Can you modify walkExpression to do a pre-order traversal of Expressions?**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class AnalyzeCircuit extends firrtl.Transform {\n",
    "  import firrtl._\n",
    "  import firrtl.ir._\n",
    "  import firrtl.Mappers._\n",
    "  import firrtl.Parser._\n",
    "  import firrtl.annotations._\n",
    "  import firrtl.PrimOps._\n",
    "\n",
    "  // Requires the [[Circuit]] form to be \"low\"\n",
    "  def inputForm = LowForm\n",
    "  // Indicates the output [[Circuit]] form to be \"low\"\n",
    "  def outputForm = LowForm\n",
    "\n",
    "  // Called by [[Compiler]] to run your pass. [[CircuitState]] contains\n",
    "  // the circuit and its form, as well as other related data.\n",
    "  def execute(state: CircuitState): CircuitState = {\n",
    "    val ledger = new Ledger()\n",
    "    val circuit = state.circuit\n",
    "\n",
    "    // Execute the function walkModule(ledger) on every [[DefModule]] in\n",
    "    // circuit, returning a new [[Circuit]] with new [[Seq]] of [[DefModule]].\n",
    "    //   - \"higher order functions\" - using a function as an object\n",
    "    //   - \"function currying\" - partial argument notation\n",
    "    //   - \"infix notation\" - fancy function calling syntax\n",
    "    //   - \"map\" - classic functional programming concept\n",
    "    //   - discard the returned new [[Circuit]] because circuit is unmodified\n",
    "    circuit map walkModule(ledger)\n",
    "\n",
    "    // Print our ledger\n",
    "    println(ledger.serialize)\n",
    "\n",
    "    // Return an unchanged [[CircuitState]]\n",
    "    state\n",
    "  }\n",
    "\n",
    "  // Deeply visits every [[Statement]] in m.\n",
    "  def walkModule(ledger: Ledger)(m: DefModule): DefModule = {\n",
    "    // Set ledger to current module name\n",
    "    ledger.setModuleName(m.name)\n",
    "\n",
    "    // Execute the function walkStatement(ledger) on every [[Statement]] in m.\n",
    "    //   - return the new [[DefModule]] (in this case, it's identical to m)\n",
    "    //   - if m does not contain [[Statement]], map returns m.\n",
    "    m map walkStatement(ledger)\n",
    "  }\n",
    "\n",
    "  // Deeply visits every [[Statement]] and [[Expression]] in s.\n",
    "  def walkStatement(ledger: Ledger)(s: Statement): Statement = {\n",
    "\n",
    "    // Execute the function walkExpression(ledger) on every [[Expression]] in s.\n",
    "    //   - discard the new [[Statement]] (in this case, it's identical to s)\n",
    "    //   - if s does not contain [[Expression]], map returns s.\n",
    "    s map walkExpression(ledger)\n",
    "\n",
    "    // Execute the function walkStatement(ledger) on every [[Statement]] in s.\n",
    "    //   - return the new [[Statement]] (in this case, it's identical to s)\n",
    "    //   - if s does not contain [[Statement]], map returns s.\n",
    "    s map walkStatement(ledger)\n",
    "  }\n",
    "\n",
    "  // Deeply visits every [[Expression]] in e.\n",
    "  //   - \"post-order traversal\" - handle e's children [[Expression]] before e\n",
    "  def walkExpression(ledger: Ledger)(e: Expression): Expression = {\n",
    "\n",
    "    // Execute the function walkExpression(ledger) on every [[Expression]] in e.\n",
    "    //   - return the new [[Expression]] (in this case, it's identical to e)\n",
    "    //   - if e does not contain [[Expression]], map returns e.\n",
    "    val visited = e map walkExpression(ledger)\n",
    "\n",
    "    visited match {\n",
    "      // If e is an adder, increment our ledger and return e.\n",
    "      case DoPrim(Add, _, _, _) =>\n",
    "        ledger.foundAdd()\n",
    "        e\n",
    "      // If e is not an adder, return e.\n",
    "      case notadd => notadd\n",
    "    }\n",
    "  }\n",
    "}"
   ]
  },
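  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The traversal order used by `walkExpression` can be seen on a toy tree. The sketch below is an illustration under stated assumptions, not FIRRTL code: `ToyExpr`, `mapChildren`, `label`, and `walk` are hypothetical names, and `mapChildren` only loosely imitates what `firrtl.Mappers` provides for real IR nodes. It records each node after its children have been visited.\n",
    "\n",
    "```scala\n",
    "import scala.collection.mutable\n",
    "\n",
    "// Hypothetical toy expression tree, standing in for firrtl.ir.Expression.\n",
    "sealed trait ToyExpr\n",
    "case class ToyRef(name: String) extends ToyExpr\n",
    "case class ToyAdd(a: ToyExpr, b: ToyExpr) extends ToyExpr\n",
    "case class ToyMul(a: ToyExpr, b: ToyExpr) extends ToyExpr\n",
    "\n",
    "// Apply f to each direct child and rebuild the node from the results.\n",
    "def mapChildren(e: ToyExpr)(f: ToyExpr => ToyExpr): ToyExpr = e match {\n",
    "  case ToyAdd(a, b) => ToyAdd(f(a), f(b))\n",
    "  case ToyMul(a, b) => ToyMul(f(a), f(b))\n",
    "  case r: ToyRef    => r\n",
    "}\n",
    "\n",
    "// Short name for a node, used only for logging the visit order.\n",
    "def label(e: ToyExpr): String = e match {\n",
    "  case ToyRef(n) => n\n",
    "  case _: ToyAdd => \"add\"\n",
    "  case _: ToyMul => \"mul\"\n",
    "}\n",
    "\n",
    "// Records the order in which nodes are handled.\n",
    "val visited = mutable.ArrayBuffer[String]()\n",
    "\n",
    "// Post-order: recurse into the children first, then handle the node itself.\n",
    "def walk(e: ToyExpr): ToyExpr = {\n",
    "  val result = mapChildren(e)(walk)\n",
    "  visited += label(result)\n",
    "  result\n",
    "}\n",
    "\n",
    "val tree = ToyAdd(ToyMul(ToyRef(\"a\"), ToyRef(\"b\")), ToyRef(\"c\"))\n",
    "walk(tree)\n",
    "println(visited.mkString(\" \")) // prints \"a b mul c add\"\n",
    "```\n",
    "\n",
    "Leaves are logged before the operators that consume them, and the root `add` is logged last."
   ]
  },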
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running our Transform\n",
    "\n",
    "Now that we've defined it, let's run it on a Chisel design! First, let's define a Chisel module."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "// Chisel stuff\n",
    "import chisel3._\n",
    "import chisel3.Input // Technicality: avoid a conflict with _root_.almond.input.Input\n",
    "import chisel3.util._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AddMe(val nInputs: Int, val width: Int) extends Module {\n",
    "  val io = IO(new Bundle {\n",
    "    val in  = Input(Vec(nInputs, UInt(width.W)))\n",
    "    val out = Output(UInt(width.W))\n",
    "  })\n",
    "  io.out := io.in.reduce(_ +& _)\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's elaborate it into FIRRTL AST syntax."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "val firrtlSerialization = chisel3.Driver.emit(() => new AddMe(8, 4))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, let's compile our FIRRTL into Verilog, but include our custom transform into the compilation. Note that it prints out the number of add ops it found!\n",
    "\n",
    "**Note** (January 2021): The following line may be broken due to a [bug](https://github.com/freechipsproject/chisel-bootcamp/issues/129)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "val verilog = compileFIRRTL(firrtlSerialization, new firrtl.VerilogCompiler(), Seq(new AnalyzeCircuit()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `compileFIRRTL` function is defined only in this tutorial - in a future section, we will describe how the process of inserting customTransforms.\n",
    "\n",
    "That's it for this section!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Scala",
   "language": "scala",
   "name": "scala"
  },
  "language_info": {
   "codemirror_mode": "text/x-scala",
   "file_extension": ".scala",
   "mimetype": "text/x-scala",
   "name": "scala211",
   "nbconvert_exporter": "script",
   "pygments_lexer": "scala",
   "version": "2.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
